"""
Implementation of learning rate schedules.

Taken and modified from the PyTorch v1.1.0 source
https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py
"""

import argparse
from torch.optim import Optimizer
import math

# Attribute/option name that selects the schedule, followed by the schedule
# type names accepted by get_config_from_args() and get_lr_from_config().
# Each type name matches a scheduler class defined in this file.
LR_SCHEDULE = "lr_schedule"
LR_RANGE_TEST = "LRRangeTest"
ONE_CYCLE = "OneCycle"
WARMUP_LR = "WarmupLR"
WARMUP_DECAY_LR = "WarmupDecayLR"
VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR]

# LR range test parameter keys. These equal the argparse option names
# registered by add_tuning_arguments() and the LRRangeTest keyword names.
LR_RANGE_TEST_MIN_LR = "lr_range_test_min_lr"
LR_RANGE_TEST_STEP_RATE = "lr_range_test_step_rate"
LR_RANGE_TEST_STEP_SIZE = "lr_range_test_step_size"
LR_RANGE_TEST_STAIRCASE = "lr_range_test_staircase"

# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably consumed elsewhere -- confirm before removing.
EDGE_VALUE = "edge_value"
MID_VALUE = "mid_value"

# OneCycle shape parameter keys (step sizes, stair counts, decay interval).
CYCLE_FIRST_STEP_SIZE = "cycle_first_step_size"
CYCLE_FIRST_STAIR_COUNT = "cycle_first_stair_count"
CYCLE_SECOND_STEP_SIZE = "cycle_second_step_size"
CYCLE_SECOND_STAIR_COUNT = "cycle_second_stair_count"
DECAY_STEP_SIZE = "decay_step_size"

# OneCycle learning-rate parameter keys.
CYCLE_MIN_LR = "cycle_min_lr"
CYCLE_MAX_LR = "cycle_max_lr"
DECAY_LR_RATE = "decay_lr_rate"

# OneCycle momentum parameter keys.
CYCLE_MIN_MOM = "cycle_min_mom"
CYCLE_MAX_MOM = "cycle_max_mom"
DECAY_MOM_RATE = "decay_mom_rate"

# WarmupLR / WarmupDecayLR parameter keys, plus the two warmup_type values
# the warmup schedulers recognise ("log" and "linear").
WARMUP_MIN_LR = "warmup_min_lr"
WARMUP_MAX_LR = "warmup_max_lr"
WARMUP_NUM_STEPS = "warmup_num_steps"
WARMUP_TYPE = "warmup_type"
WARMUP_LOG_RATE = "log"
WARMUP_LINEAR_RATE = "linear"

# Same spelling as the WarmupDecayLR constructor parameter.
TOTAL_NUM_STEPS = "total_num_steps"


def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` hands the raw string to ``bool()``, so every non-empty
    token -- including ``"False"`` and ``"0"`` -- evaluates to ``True``.
    This converter interprets the usual spellings instead.

    Args:
        value: The raw option value (a string from argparse, or a bool).

    Returns:
        bool: The parsed flag. An empty string maps to ``False``, which is
        what ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised
            true/false spelling.
    """
    if isinstance(value, bool):
        return value

    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value)
    )


def add_tuning_arguments(parser):
    """Register the convergence-tuning command-line options on ``parser``.

    All options are added to a "Convergence Tuning" argument group. They
    cover the schedule selector plus the parameters of the LR range test,
    1Cycle and warmup schedules.

    Args:
        parser (argparse.ArgumentParser): Parser to extend in place.

    Returns:
        argparse.ArgumentParser: The same ``parser`` object, for chaining.
    """
    group = parser.add_argument_group(
        "Convergence Tuning", "Convergence tuning configurations"
    )

    # LR scheduler
    group.add_argument(
        "--lr_schedule", type=str, default=None, help="LR schedule for training."
    )

    # Learning rate range test
    group.add_argument(
        "--lr_range_test_min_lr", type=float, default=0.001, help="Starting lr value."
    )
    group.add_argument(
        "--lr_range_test_step_rate",
        type=float,
        default=1.0,
        help="scaling rate for LR range test.",
    )
    group.add_argument(
        "--lr_range_test_step_size",
        type=int,
        default=1000,
        help="training steps per LR change.",
    )
    group.add_argument(
        "--lr_range_test_staircase",
        # A real converter is required here: with type=bool the value
        # "False" would silently enable staircase mode.
        type=_str_to_bool,
        default=False,
        help="use staircase scaling for LR range test.",
    )

    # OneCycle schedule
    group.add_argument(
        "--cycle_first_step_size",
        type=int,
        default=1000,
        help="size of first step of 1Cycle schedule (training steps).",
    )
    group.add_argument(
        "--cycle_first_stair_count",
        type=int,
        default=-1,
        help="first stair count for 1Cycle schedule.",
    )
    group.add_argument(
        "--cycle_second_step_size",
        type=int,
        default=-1,
        help="size of second step of 1Cycle schedule (default first_step_size).",
    )
    group.add_argument(
        "--cycle_second_stair_count",
        type=int,
        default=-1,
        help="second stair count for 1Cycle schedule.",
    )
    group.add_argument(
        "--decay_step_size",
        type=int,
        default=1000,
        help="size of intervals for applying post cycle decay (training steps).",
    )

    # 1Cycle LR
    group.add_argument(
        "--cycle_min_lr", type=float, default=0.01, help="1Cycle LR lower bound."
    )
    group.add_argument(
        "--cycle_max_lr", type=float, default=0.1, help="1Cycle LR upper bound."
    )
    group.add_argument(
        "--decay_lr_rate", type=float, default=0.0, help="post cycle LR decay rate."
    )

    # 1Cycle Momentum
    group.add_argument(
        "--cycle_momentum",
        default=False,
        action="store_true",
        help="Enable 1Cycle momentum schedule.",
    )
    group.add_argument(
        "--cycle_min_mom", type=float, default=0.8, help="1Cycle momentum lower bound."
    )
    group.add_argument(
        "--cycle_max_mom", type=float, default=0.9, help="1Cycle momentum upper bound."
    )
    group.add_argument(
        "--decay_mom_rate",
        type=float,
        default=0.0,
        help="post cycle momentum decay rate.",
    )

    # Warmup LR
    group.add_argument(
        "--warmup_min_lr",
        type=float,
        default=0,
        help="WarmupLR minimum/initial LR value",
    )
    group.add_argument(
        "--warmup_max_lr", type=float, default=0.001, help="WarmupLR maximum LR value."
    )
    group.add_argument(
        "--warmup_num_steps",
        type=int,
        default=1000,
        help="WarmupLR step count for LR warmup.",
    )
    group.add_argument(
        "--warmup_type",
        type=str,
        default=WARMUP_LOG_RATE,
        help="WarmupLR increasing function during warmup",
    )
    return parser


def parse_arguments():
    """Parse the convergence-tuning options from the process command line.

    A fresh ``ArgumentParser`` is populated by ``add_tuning_arguments`` and
    run with ``parse_known_args``, so options it does not know about are
    collected rather than rejected.

    Returns:
        tuple: ``(namespace, leftovers)`` -- the parsed tuning options and
        the list of arguments the parser did not recognise.
    """
    tuning_parser = add_tuning_arguments(argparse.ArgumentParser())
    return tuning_parser.parse_known_args()


def override_lr_range_test_params(args, params):
    """Copy the LR range test options found on ``args`` into ``params``.

    An option is copied only when ``args`` has the attribute and its value
    is not ``None``. Each key constant is also the attribute name on
    ``args``, so one lookup per key is enough.

    Args:
        args: Object carrying the parsed options as attributes.
        params (dict): Mapping updated in place.
    """
    for option in (
        LR_RANGE_TEST_MIN_LR,
        LR_RANGE_TEST_STEP_RATE,
        LR_RANGE_TEST_STEP_SIZE,
        LR_RANGE_TEST_STAIRCASE,
    ):
        supplied = getattr(args, option, None)
        if supplied is not None:
            params[option] = supplied


def override_1cycle_params(args, params):
    """Copy the 1Cycle options found on ``args`` into ``params``.

    Covers the cycle shape, learning-rate and momentum settings. An option
    is copied only when ``args`` has the attribute and its value is not
    ``None``. Each key constant is also the attribute name on ``args``.

    Args:
        args: Object carrying the parsed options as attributes.
        params (dict): Mapping updated in place.
    """
    one_cycle_options = (
        # cycle shape
        CYCLE_FIRST_STEP_SIZE,
        CYCLE_FIRST_STAIR_COUNT,
        CYCLE_SECOND_STEP_SIZE,
        CYCLE_SECOND_STAIR_COUNT,
        DECAY_STEP_SIZE,
        # 1Cycle LR params
        CYCLE_MIN_LR,
        CYCLE_MAX_LR,
        DECAY_LR_RATE,
        # 1Cycle MOM params
        CYCLE_MIN_MOM,
        CYCLE_MAX_MOM,
        DECAY_MOM_RATE,
    )
    for option in one_cycle_options:
        supplied = getattr(args, option, None)
        if supplied is not None:
            params[option] = supplied


def override_warmupLR_params(args, params):
    """Copy the warmup options found on ``args`` into ``params``.

    An option is copied only when ``args`` has the attribute and its value
    is not ``None``. Each key constant is also the attribute name on
    ``args``.

    Args:
        args: Object carrying the parsed options as attributes.
        params (dict): Mapping updated in place.
    """
    for option in (WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS, WARMUP_TYPE):
        supplied = getattr(args, option, None)
        if supplied is not None:
            params[option] = supplied


def override_params(args, params):
    """Apply every schedule's option overrides from ``args`` onto ``params``.

    Runs the LR range test, 1Cycle and warmup override helpers in that
    order against the same ``params`` mapping.

    Args:
        args: Object carrying the parsed options as attributes.
        params (dict): Mapping updated in place.
    """
    override_steps = (
        override_lr_range_test_params,  # LR range test params
        override_1cycle_params,  # 1Cycle params
        override_warmupLR_params,  # WarmupLR params
    )
    for apply_overrides in override_steps:
        apply_overrides(args, params)


def get_config_from_args(args):
    """Build a scheduler config dict from parsed command-line options.

    Args:
        args: Object carrying the parsed options as attributes.

    Returns:
        tuple: ``(config, None)`` on success, where ``config`` has the keys
        ``"type"`` (schedule name) and ``"params"`` (options for that
        schedule). Returns ``(None, message)`` when no schedule was
        selected or the selected schedule is not in ``VALID_LR_SCHEDULES``.
    """
    schedule = getattr(args, LR_SCHEDULE, None)
    if schedule is None:
        return None, "--{} not specified on command line".format(LR_SCHEDULE)

    if schedule not in VALID_LR_SCHEDULES:
        return None, "{} is not supported LR schedule".format(schedule)

    # Both warmup schedules use the warmup options, so they are the fallback.
    collectors = {
        LR_RANGE_TEST: override_lr_range_test_params,
        ONE_CYCLE: override_1cycle_params,
    }
    collect = collectors.get(schedule, override_warmupLR_params)

    schedule_params = {}
    collect(args, schedule_params)

    return {"type": schedule, "params": schedule_params}, None


def get_lr_from_config(config):
    """Extract a representative learning rate from a scheduler config.

    Args:
        config (dict): Mapping with a ``"type"`` entry (schedule name) and
            a ``"params"`` entry (that schedule's options).

    Returns:
        tuple: ``(lr, "")`` on success. The value is the range test's
        minimum LR, the 1Cycle maximum LR, or the warmup maximum LR,
        depending on the schedule. Returns ``(None, message)`` when a
        required entry is missing or the schedule name is not valid.
    """
    required_entries = (
        ("type", "LR schedule type not defined in config"),
        ("params", "LR schedule params not defined in config"),
    )
    for entry, problem in required_entries:
        if entry not in config:
            return None, problem

    schedule_name = config["type"]
    schedule_params = config["params"]

    if schedule_name not in VALID_LR_SCHEDULES:
        return None, "{} is not a valid LR schedule".format(schedule_name)

    if schedule_name == LR_RANGE_TEST:
        lr_key = LR_RANGE_TEST_MIN_LR
    elif schedule_name == ONE_CYCLE:
        lr_key = CYCLE_MAX_LR
    else:
        # Warmup LR
        lr_key = WARMUP_MAX_LR
    return schedule_params[lr_key], ""


"""
Only optimizers that are subclass of torch.optim.Optimizer are supported. So check the passed optimizer and wrapped
optimizer to see if requirement is satisfied.
TODO: Looking under the hood to examine the wrapped optimizer is a hack that requires a better long-term fix.
"""


def get_torch_optimizer(optimizer):
    """Return the ``torch.optim.Optimizer`` behind ``optimizer``.

    Args:
        optimizer: Either an ``Optimizer`` instance or an object whose
            ``optimizer`` attribute is one.

    Returns:
        Optimizer: ``optimizer`` itself, or the instance it wraps.

    Raises:
        TypeError: If neither the object nor its ``optimizer`` attribute is
            an ``Optimizer`` instance.
    """
    if not isinstance(optimizer, Optimizer):
        # Only look inside the wrapper when the object itself does not qualify.
        wrapped = getattr(optimizer, "optimizer", None)
        if not isinstance(wrapped, Optimizer):
            raise TypeError(
                "{} is not a subclass of torch.optim.Optimizer".format(
                    type(optimizer).__name__
                )
            )
        return wrapped

    return optimizer


class LRRangeTest(object):
    """Learning rate range test (LRRT) scheduler.

    Starting from a base value, the learning rate of every parameter group
    is scaled by ``1 + step_rate * interval``, where ``interval`` grows with
    the number of batches processed. The policy is described in the paper
    `A disciplined approach to neural network hyper-parameters: Part 1`_.

    LRRT is used to find the largest LR that trains a model without
    divergence, and can be used to configure the LR boundaries of cyclic
    LR schedules.

    The learning rate changes after every batch, so `step` should be called
    after each batch has been used for training.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_range_test_min_lr (float or list): Initial learning rate, the
            lower boundary of the range test, for each parameter group.
            Default: 1e-3
        lr_range_test_step_size (int): Interval of training steps over
            which the learning rate increases. Default: 2000
        lr_range_test_step_rate (float): Scaling rate for the range test.
            Default: 1.0
        lr_range_test_staircase (bool): Scale in staircase fashion rather
            than continuously. Default: False
        last_batch_iteration (int): Index of the last batch, used when
            resuming a training job. It counts *batches*, not epochs,
            because `step()` is invoked after each batch. With
            last_batch_iteration=-1 the schedule starts from the beginning.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = LRRangeTest(optimizer)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

    .. _A disciplined approach to neural network hyper-parameters: Part 1:
        https://arxiv.org/abs/1803.09820
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_range_test_min_lr: float = 1e-3,
        lr_range_test_step_size: int = 2000,
        lr_range_test_step_rate: float = 1.0,
        lr_range_test_staircase: bool = False,
        last_batch_iteration: int = -1,
    ):
        self.optimizer = get_torch_optimizer(optimizer)
        group_count = len(self.optimizer.param_groups)

        # One base LR per parameter group; a scalar is repeated for each group.
        if isinstance(lr_range_test_min_lr, (list, tuple)):
            supplied = len(lr_range_test_min_lr)
            if supplied != group_count:
                raise ValueError(
                    "expected {} lr_range_test_min_lr, got {}".format(
                        group_count, supplied
                    )
                )
            self.min_lr = list(lr_range_test_min_lr)
        else:
            self.min_lr = [lr_range_test_min_lr] * group_count

        self.step_size = lr_range_test_step_size
        self.step_rate = lr_range_test_step_rate
        self.last_batch_iteration = last_batch_iteration
        self.staircase = lr_range_test_staircase

        if lr_range_test_staircase:
            self.interval_fn = self._staircase_interval
        else:
            self.interval_fn = self._continuous_interval

        # A fresh schedule starts the optimizer at the base learning rates.
        if last_batch_iteration == -1:
            self._update_optimizer(self.min_lr)

    def _staircase_interval(self):
        """Return the number of whole step-size intervals reached so far."""
        batches_seen = float(self.last_batch_iteration + 1)
        return math.floor(batches_seen / self.step_size)

    def _continuous_interval(self):
        """Return the possibly fractional number of intervals reached so far."""
        batches_seen = float(self.last_batch_iteration + 1)
        return batches_seen / self.step_size

    def _get_increase(self):
        """Return the factor applied to the base learning rates."""
        return 1 + self.step_rate * self.interval_fn()

    def get_lr(self):
        """Return the scaled learning rate for each parameter group."""
        factor = self._get_increase()
        scaled = []
        for base_lr in self.min_lr:
            scaled.append(base_lr * factor)
        return scaled

    def get_last_lr(self):
        """Return last computed learning rate by current scheduler."""
        assert getattr(self, "_last_lr", None) is not None, "need to call step() first"
        return self._last_lr

    def _update_optimizer(self, group_lrs):
        """Write ``group_lrs`` into the optimizer's parameter groups."""
        for group, new_lr in zip(self.optimizer.param_groups, group_lrs):
            group["lr"] = new_lr

    def step(self, batch_iteration=None):
        """Advance the schedule and push the new learning rates.

        Args:
            batch_iteration (int, optional): Batch index to move to. When
                omitted, the stored index is advanced by one.
        """
        if batch_iteration is None:
            self.last_batch_iteration = self.last_batch_iteration + 1
        else:
            self.last_batch_iteration = batch_iteration

        self._update_optimizer(self.get_lr())
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def state_dict(self):
        """Return the resumable scheduler state (the last batch index)."""
        return dict(last_batch_iteration=self.last_batch_iteration)

    def load_state_dict(self, sd):
        """Restore the last batch index from a ``state_dict()`` mapping."""
        self.last_batch_iteration = sd["last_batch_iteration"]


class OneCycle(object):
    """Sets the learning rate of each parameter group according to
    1Cycle learning rate policy (1CLR). 1CLR is a variation of the
    Cyclical Learning Rate (CLR) policy that involves one cycle followed by
    decay. The policy simultaneously cycles the learning rate (and momentum)
    between two boundaries with a constant frequency, as detailed in
    the paper `A disciplined approach to neural network hyper-parameters`_.

    1CLR policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This implementation was adapted from the github repo: `pytorch/pytorch`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        cycle_min_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group. A list or
            tuple must have one entry per parameter group.
        cycle_max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (cycle_max_lr - cycle_min_lr).
            The lr at any cycle is the sum of cycle_min_lr
            and some scaling of the amplitude; therefore
            cycle_max_lr may not actually be reached depending on
            scaling function.
        decay_lr_rate (float): Decay rate for learning rate. Default: 0.
        cycle_first_step_size (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        cycle_second_step_size (int): Number of training iterations in the
            decreasing half of a cycle. If cycle_second_step_size is None,
            it is set to cycle_first_step_size. Default: None
        cycle_first_stair_count (int): Number of stairs in first half of cycle
            phase. Default: 0. The value is stored but not used by the
            schedule computations in this class.
        cycle_second_stair_count (int): Number of stairs in second half of cycle
            phase. If None, cycle_first_stair_count is used. Default: None.
            The value is stored but not used by the schedule computations
            in this class.
        decay_step_size (int): Intervals for applying decay in decay phase. Default: 0, means no decay.
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'cycle_min_mom' and 'cycle_max_mom'.
            Default: True
        cycle_min_mom (float or list): Initial momentum which is the
            lower boundary in the cycle for each parameter group.
            Default: 0.8
        cycle_max_mom (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (cycle_max_mom - cycle_min_mom).
            The momentum at any cycle is the difference of cycle_max_mom
            and some scaling of the amplitude; therefore
            cycle_min_mom may not actually be reached depending on
            scaling function. Default: 0.9
        decay_mom_rate (float): Decay rate for momentum. Default: 0.
        last_batch_iteration (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_batch_iteration=-1, the schedule is started from the beginning.
            Default: -1

    Raises:
        ValueError: If a list/tuple value does not have one entry per
            optimizer parameter group.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = OneCycle(optimizer, 0.0001, 0.0010)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820
    """

    def __init__(
        self,
        optimizer,
        cycle_min_lr,
        cycle_max_lr,
        decay_lr_rate=0.0,
        cycle_first_step_size=2000,
        cycle_second_step_size=None,
        cycle_first_stair_count=0,
        cycle_second_stair_count=None,
        decay_step_size=0,
        cycle_momentum=True,
        cycle_min_mom=0.8,
        cycle_max_mom=0.9,
        decay_mom_rate=0.0,
        last_batch_iteration=-1,
    ):

        self.optimizer = get_torch_optimizer(optimizer)

        # Initialize cycle shape
        self._initialize_cycle(
            cycle_first_step_size,
            cycle_second_step_size,
            cycle_first_stair_count,
            cycle_second_stair_count,
            decay_step_size,
        )

        # Initialize cycle lr
        self._initialize_lr(
            self.optimizer,
            cycle_min_lr,
            cycle_max_lr,
            decay_lr_rate,
            last_batch_iteration,
        )

        # Initialize cyclic momentum
        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            self._initialize_momentum(
                self.optimizer,
                cycle_min_mom,
                cycle_max_mom,
                decay_mom_rate,
                last_batch_iteration,
            )

        # Initialize batch iteration tracker
        self.last_batch_iteration = last_batch_iteration

    def _format_param(self, optimizer, param_value, param_name):
        """Expand ``param_value`` to one entry per optimizer parameter group.

        A list or tuple is copied after its length is checked; any other
        value is repeated once per group.

        Raises:
            ValueError: If a list/tuple has the wrong number of entries.
        """
        group_count = len(optimizer.param_groups)
        if isinstance(param_value, (list, tuple)):
            if len(param_value) != group_count:
                raise ValueError(
                    "expected {} value for {}, got {}".format(
                        group_count, param_name, len(param_value)
                    )
                )
            return list(param_value)
        return [param_value] * group_count

    # Configure cycle shape

    def _initialize_cycle(
        self,
        cycle_first_step_size,
        cycle_second_step_size,
        cycle_first_stair_count,
        cycle_second_stair_count,
        decay_step_size,
    ):
        """Record the cycle length, the rising fraction and the decay interval."""
        cycle_first_step_size = float(cycle_first_step_size)
        cycle_second_step_size = (
            float(cycle_second_step_size)
            if cycle_second_step_size is not None
            else cycle_first_step_size
        )

        self.total_size = cycle_first_step_size + cycle_second_step_size
        # Fraction of the cycle spent in the increasing half.
        self.step_ratio = cycle_first_step_size / self.total_size
        self.first_stair_count = cycle_first_stair_count
        self.second_stair_count = (
            cycle_first_stair_count
            if cycle_second_stair_count is None
            else cycle_second_stair_count
        )
        self.decay_step_size = decay_step_size

        # A zero decay interval disables post-cycle decay for both lr and momentum.
        if math.isclose(self.decay_step_size, 0):
            self.skip_lr_decay = True
            self.skip_mom_decay = True
        else:
            self.skip_lr_decay = False
            self.skip_mom_decay = False

    # Configure lr schedule
    def _initialize_lr(
        self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration
    ):
        """Store per-group lr bounds and, for a fresh schedule, apply the minimum."""
        # Validate both bounds before touching the optimizer so a bad
        # argument cannot leave the parameter groups half-updated.
        self.min_lrs = self._format_param(optimizer, cycle_min_lr, "cycle_min_lr")
        self.max_lrs = self._format_param(optimizer, cycle_max_lr, "cycle_max_lr")

        if last_batch_iteration == -1:
            for lr, group in zip(self.min_lrs, optimizer.param_groups):
                group["lr"] = lr

        self.decay_lr_rate = decay_lr_rate

        if math.isclose(self.decay_lr_rate, 0):
            self.skip_lr_decay = True

    # Configure momentum schedule
    def _initialize_momentum(
        self,
        optimizer,
        cycle_min_mom,
        cycle_max_mom,
        decay_mom_rate,
        last_batch_iteration,
    ):
        """Store per-group momentum bounds, or disable cycling if unsupported."""
        if "betas" not in optimizer.defaults:
            optimizer_name = type(optimizer).__name__
            print(
                f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
            )
            self.cycle_momentum = False
            return

        self.decay_mom_rate = decay_mom_rate
        # NOTE(review): the second beta is hard-coded to 0.99, so whatever
        # second beta the optimizer was configured with is overwritten when
        # these tuples are written into the parameter groups.
        self.min_moms = [
            (mom, 0.99)
            for mom in self._format_param(optimizer, cycle_min_mom, "cycle_min_mom")
        ]
        self.max_moms = [
            (mom, 0.99)
            for mom in self._format_param(optimizer, cycle_max_mom, "cycle_max_mom")
        ]

        if last_batch_iteration == -1:
            for momentum, group in zip(self.min_moms, optimizer.param_groups):
                group["betas"] = momentum

        if math.isclose(self.decay_mom_rate, 0):
            self.skip_mom_decay = True

    def _get_scale_factor(self):
        """Return the position within the cycle as a rise-then-fall factor."""
        batch_iteration = self.last_batch_iteration + 1
        cycle = math.floor(1 + batch_iteration / self.total_size)
        # x is the fractional position inside the current cycle.
        x = 1.0 + batch_iteration / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        return scale_factor

    def _get_cycle_mom(self):
        """Return per-group ``betas`` for the cycle phase."""
        scale_factor = self._get_scale_factor()
        momentums = []
        for base_betas, max_betas in zip(self.min_moms, self.max_moms):
            cycle_min_mom = base_betas[0]
            cycle_max_mom = max_betas[0]
            base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
            # Momentum moves inversely to lr: it starts from the maximum.
            momentum = cycle_max_mom - base_height
            momentums.append((momentum, base_betas[1]))
        return momentums

    def _get_cycle_lr(self):
        """Return per-group learning rates for the cycle phase."""
        scale_factor = self._get_scale_factor()
        lrs = []
        for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
            base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
            lr = cycle_min_lr + base_height
            lrs.append(lr)

        return lrs

    def _get_decay_mom(self, decay_batch_iteration):
        """Return per-group ``betas`` for the post-cycle decay phase."""
        if self.skip_mom_decay:
            return self.max_moms

        decay_interval = decay_batch_iteration / self.decay_step_size
        mom_decay_factor = 1 + self.decay_mom_rate * decay_interval
        momentums = [
            (beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms
        ]

        return momentums

    def _get_decay_lr(self, decay_batch_iteration):
        """Calculates the learning rate for the post-cycle decay phase.

        Args:
            decay_batch_iteration: Number of batches elapsed since the cycle
                ended.

        Returns:
            list: ``self.min_lrs`` when lr decay is skipped, otherwise each
            minimum lr divided by the decay factor.
        """
        if self.skip_lr_decay:
            return self.min_lrs

        decay_interval = decay_batch_iteration / self.decay_step_size
        lr_decay_factor = 1 + self.decay_lr_rate * decay_interval
        lrs = [cycle_min_lr / lr_decay_factor for cycle_min_lr in self.min_lrs]

        return lrs

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_batch_iteration` as the last batch index.
        """
        if self.last_batch_iteration < self.total_size:
            return self._get_cycle_lr()
        return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1)

    def get_mom(self):
        """Calculates the momentum at batch index. This function treats
        `self.last_batch_iteration` as the last batch index.

        Returns ``None`` when momentum cycling is disabled.
        """
        if not self.cycle_momentum:
            return None

        if self.last_batch_iteration < self.total_size:
            return self._get_cycle_mom()
        return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1)

    def get_last_lr(self):
        """Return last computed learning rate by current scheduler."""
        assert getattr(self, "_last_lr", None) is not None, "need to call step() first"
        return self._last_lr

    def step(self, batch_iteration=None):
        """Updates the optimizer with the learning rate for the last batch index.
        `self.last_batch_iteration` is treated as the last batch index.

        If self.cycle_momentum is true, also updates optimizer momentum.

        Args:
            batch_iteration (int, optional): Batch index to move to. When
                omitted, the stored index is advanced by one.
        """
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1

        self.last_batch_iteration = batch_iteration
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

        if self.cycle_momentum:
            momentums = self.get_mom()
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group["betas"] = momentum

    def state_dict(self):
        """Return the resumable scheduler state (the last batch index)."""
        return {"last_batch_iteration": self.last_batch_iteration}

    def load_state_dict(self, sd):
        """Restore the last batch index from a ``state_dict()`` mapping."""
        self.last_batch_iteration = sd["last_batch_iteration"]


class WarmupLR(object):
    """Increase the learning rate of each parameter group from min lr to max lr
    over warmup_num_steps steps, and then fix at max lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        warmup_min_lr (float or list): minimum learning rate. Default: 0
        warmup_max_lr (float or list): maximum learning rate. Default: 0.001
        warmup_num_steps (int): number of steps to warm up from min_lr to max_lr.
            Values below 2 are raised to 2. Default: 1000
        warmup_type (str): increasing function from min_lr to max_lr during
            warmup, either 'log' or 'linear'. Any other value falls back to
            'log' with a printed notice. Default: log
        last_batch_iteration (int): The index of the last batch. Default: -1.

    Raises:
        ValueError: If a list/tuple learning rate does not have one entry
            per optimizer parameter group.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = WarmupLR(optimizer)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_min_lr: float = 0.0,
        warmup_max_lr: float = 0.001,
        warmup_num_steps: int = 1000,
        warmup_type: str = WARMUP_LOG_RATE,
        last_batch_iteration: int = -1,
    ):

        self.optimizer = get_torch_optimizer(optimizer)

        self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr")
        self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr")
        self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
        # At least 2 steps keeps math.log(warmup_num_steps) below non-zero.
        self.warmup_num_steps = max(2, warmup_num_steps)
        # Currently only support linear and log function
        if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}:
            print(
                f"Using unknown warmup_type: {warmup_type}. The increasing function "
                f"is set to default (log)"
            )
            warmup_type = WARMUP_LOG_RATE
        self.warmup_type = warmup_type
        self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
        self.last_batch_iteration = last_batch_iteration

    def get_lr(self):
        """Return the learning rate of each parameter group for the current step.

        Before the schedule has started (negative batch index) a notice is
        printed and ``[0.0]`` is returned.
        """
        if self.last_batch_iteration < 0:
            print(
                "Attempting to get learning rate from scheduler before it has started"
            )
            # NOTE(review): single-element list regardless of how many
            # parameter groups the optimizer has.
            return [0.0]
        gamma = self._get_gamma()
        return [
            min_lr + (delta_lr * gamma)
            for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)
        ]

    def get_last_lr(self):
        """Return last computed learning rate by current scheduler."""
        assert getattr(self, "_last_lr", None) is not None, "need to call step() first"
        return self._last_lr

    def step(self, last_batch_iteration=None):
        """Advance the schedule and write the new learning rates to the optimizer.

        Args:
            last_batch_iteration (int, optional): Batch index to move to.
                When omitted, the stored index is advanced by one.
        """
        if last_batch_iteration is None:
            last_batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = last_batch_iteration
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def state_dict(self):
        """Return the resumable scheduler state (the last batch index)."""
        return {"last_batch_iteration": self.last_batch_iteration}

    def load_state_dict(self, sd):
        """Restore the last batch index from a ``state_dict()`` mapping."""
        self.last_batch_iteration = sd["last_batch_iteration"]

    def _get_gamma(self):
        """Return the interpolation factor between min lr and max lr.

        During warmup the factor follows the configured log or linear
        function; afterwards it is 1.0, which pins the rate at max lr.
        """
        if self.last_batch_iteration < self.warmup_num_steps:
            if self.warmup_type == WARMUP_LOG_RATE:
                return self.inverse_log_warm_up * math.log(
                    self.last_batch_iteration + 1
                )
            elif self.warmup_type == WARMUP_LINEAR_RATE:
                return self.last_batch_iteration / self.warmup_num_steps
        return 1.0

    def _format_param(self, optimizer, param_value, param_name):
        """Expand ``param_value`` to one entry per optimizer parameter group.

        A list or tuple is copied after its length is checked; any other
        value is repeated once per group.

        Raises:
            ValueError: If a list/tuple has the wrong number of entries.
        """
        if isinstance(param_value, (list, tuple)):
            if len(param_value) != len(optimizer.param_groups):
                raise ValueError(
                    "expected {} value for {}, got {}".format(
                        len(optimizer.param_groups),
                        param_name,
                        # Report how many values were supplied; this used to
                        # format an unrelated FileNotFoundError instance.
                        len(param_value),
                    )
                )
            return list(param_value)
        return [param_value] * len(optimizer.param_groups)


class WarmupDecayLR(WarmupLR):
    """Warm the learning rate up from min lr to max lr over warmup_num_steps
    steps, then decay it linearly over the remaining training steps.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        total_num_steps (int): total number of training steps
        warmup_min_lr (float or list): minimum learning rate. Default: 0
        warmup_max_lr (float or list): maximum learning rate. Default: 0.001
        warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
        warmup_type (str): increasing function from min_lr to max_lr during
            warmup, either 'log' or 'linear'. Default: log
        last_batch_iteration (int): The index of the last batch. Default: -1.
    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = WarmupDecayLR(optimizer, 1000000)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

    """

    def __init__(
        self,
        optimizer: Optimizer,
        total_num_steps: int,
        warmup_min_lr: float = 0.0,
        warmup_max_lr: float = 0.001,
        warmup_num_steps: int = 1000,
        warmup_type: str = WARMUP_LOG_RATE,
        last_batch_iteration: int = -1,
    ):
        self.total_num_steps = total_num_steps
        super().__init__(
            optimizer,
            warmup_min_lr,
            warmup_max_lr,
            warmup_num_steps,
            warmup_type,
            last_batch_iteration,
        )

        # The comparison uses the stored warmup length, while the message
        # reports the values exactly as they were passed in.
        warmup_exceeds_total = self.total_num_steps < self.warmup_num_steps
        if warmup_exceeds_total:
            notice = "total_num_steps {} is less than warmup_num_steps {}".format(
                total_num_steps, warmup_num_steps
            )
            print(notice)

    def _get_gamma(self):
        """Return the interpolation factor between min lr and max lr.

        During warmup the factor follows the configured log or linear
        function. Afterwards it is the share of post-warmup steps still
        remaining, never below 0.0.
        """
        current = self.last_batch_iteration

        if current < self.warmup_num_steps:
            if self.warmup_type == WARMUP_LOG_RATE:
                return self.inverse_log_warm_up * math.log(current + 1)
            if self.warmup_type == WARMUP_LINEAR_RATE:
                return current / self.warmup_num_steps

        steps_left = float(self.total_num_steps - current)
        # Clamp the denominator so a warmup as long as the whole run cannot
        # divide by zero.
        decay_span = float(max(1.0, self.total_num_steps - self.warmup_num_steps))
        return max(0.0, steps_left / decay_span)
